{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate a SWOT swath.\n",
    "\n",
    "This example shows how to initialize the simulation parameters (errors, SSH interpolation, orbit), generate an orbit, generate a swath, interpolate the SSH, and simulate the measurement errors. Finally, we visualize the simulated data.\n",
    "\n",
    "### Simulation setup\n",
    "\n",
    "The configuration is a dictionary that maps each expected parameter to its value. The parameters are described in the [online help](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.settings.Parameters.html#swot_simulator.settings.Parameters).\n",
    "\n",
    "This dictionary can be loaded from a Python file using the [eval_config_file](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.settings.eval_config_file.html#swot_simulator.settings.eval_config_file) function, but it can also be declared directly in Python code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import swot_simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "configuration = dict(\n",
    "    # The swath contains in its centre a central pixel divided in two by the\n",
    "    # reference ground track.\n",
    "    central_pixel=True,\n",
    "    # Distance, in km, between two points along track direction.\n",
    "    delta_al=2.0,\n",
    "    # Distance, in km, between two points across track direction.\n",
    "    delta_ac=2.0,\n",
    "    # Distance, in km, between the nadir and the center of the first pixel of the\n",
    "    # swath\n",
    "    half_gap=2.0,\n",
    "    # Distance, in km, between the nadir and the center of the last pixel of the\n",
    "    # swath\n",
    "    half_swath=70.0,\n",
    "    # Limits of SWOT swath requirements. Measurements outside the span will be\n",
    "    # set with fill values.\n",
    "    requirement_bounds=[10, 60],\n",
    "    # Ephemeris file to read containing the satellite's orbit.\n",
    "    ephemeris=swot_simulator.DATA.joinpath(\n",
    "        \"ephemeris_calval_june2015_ell.txt\"),\n",
    "    # Generation of measurement noise.\n",
    "    noise=[\n",
    "        'altimeter',\n",
    "        'baseline_dilation',\n",
    "        'karin',\n",
    "        'roll_phase',\n",
    "        #'orbital', (This generator consumes too much memory to run with binder)\n",
    "        'timing',\n",
    "        'wet_troposphere',\n",
    "    ],\n",
    "    # File containing spectrum of instrument error\n",
    "    error_spectrum=swot_simulator.DATA.joinpath(\"error_spectrum.nc\"),\n",
    "    # KaRIN file containing spectrum for several SWH\n",
    "    karin_noise=swot_simulator.DATA.joinpath(\"karin_noise_v2.nc\"),\n",
    "    # The plug-in handling the SSH interpolation under the satellite swath.\n",
    "    #ssh_plugin = TODO\n",
    ")"
   ]
  },
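  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of how the geometric settings above fit together, the sketch below builds across-track pixel centres from `half_gap`, `half_swath` and `delta_ac`, then flags the pixels that fall inside `requirement_bounds`. The helpers `across_track_centres` and `inside_requirements` are hypothetical and not part of `swot_simulator`; the regular spacing and the test on the absolute distance are assumptions, so the layout the simulator actually builds may differ.\n",
    "\n",
    "```python\n",
    "def across_track_centres(half_gap, half_swath, delta_ac, central_pixel):\n",
    "    # Hypothetical helper: signed across-track distances (km) of the pixel\n",
    "    # centres, assuming regular spacing from half_gap to half_swath.\n",
    "    count = int(round((half_swath - half_gap) / delta_ac)) + 1\n",
    "    right = [half_gap + index * delta_ac for index in range(count)]\n",
    "    left = [-value for value in reversed(right)]\n",
    "    centre = [0.0] if central_pixel else []\n",
    "    return left + centre + right\n",
    "\n",
    "\n",
    "def inside_requirements(x_ac, bounds):\n",
    "    # Assumption: a pixel is kept when its absolute across-track distance\n",
    "    # lies within the requirement bounds (km).\n",
    "    lower, upper = bounds\n",
    "    return [lower <= abs(value) <= upper for value in x_ac]\n",
    "\n",
    "\n",
    "# Values taken from the configuration used in this notebook.\n",
    "x_ac = across_track_centres(2.0, 70.0, 2.0, central_pixel=True)\n",
    "mask = inside_requirements(x_ac, [10, 60])\n",
    "print(len(x_ac), sum(mask))\n",
    "```\n",
    "\n",
    "With the values used in this notebook, the sketch yields 71 pixel centres, 52 of which lie within the requirement bounds."
   ]
  },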
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create the parameter object for our simulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import swot_simulator.settings\n",
    "\n",
    "parameters = swot_simulator.settings.Parameters(configuration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "**Note**\n",
    "\n",
    "The [Parameters](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.settings.Parameters.html#swot_simulator.settings.Parameters) class exposes the [load_default](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.settings.Parameters.load_default.html#swot_simulator.settings.Parameters.load_default) method, which returns the default parameters of the simulation:\n",
    "\n",
    "```python\n",
    "parameters = swot_simulator.settings.Parameters.load_default()\n",
    "```\n",
    "\n",
    "It is also possible to [automatically load the\n",
    "dictionary](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.settings.template.html)\n",
    "containing the simulation parameters, adapt it to your needs, and then\n",
    "create the parameters of your simulation.\n",
    "\n",
    "```python\n",
    "configuration = swot_simulator.settings.template(python=True)\n",
    "parameters = swot_simulator.settings.Parameters(configuration)\n",
    "```\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SSH interpolation\n",
    "\n",
    "The configuration above is enough to simulate a swath, but the interpolation of the SSH under the satellite swath is still undefined. If you do not need the SSH, you can skip this step.\n",
    "\n",
    "For our example, we use the SSH of the CMEMS grids provided on the Pangeo site."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import intake\n",
    "\n",
    "cat = intake.open_catalog(\"https://raw.githubusercontent.com/pangeo-data/\"\n",
    "                          \"pangeo-datastore/master/intake-catalogs/master.yaml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = cat.ocean.sea_surface_height.to_dask()\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To interpolate the SSH, we implement a class that defines a method to\n",
    "interpolate the data under the swath. This class must derive from the\n",
    "[CartesianGridHandler](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.plugins.data_handler.CartesianGridHandler.html)\n",
    "class so that the class managing the parameters accepts it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyinterp.backends.xarray\n",
    "import numpy\n",
    "import xarray\n",
    "#\n",
    "import swot_simulator.plugins.data_handler\n",
    "\n",
    "\n",
    "class CMEMS(swot_simulator.plugins.data_handler.CartesianGridHandler):\n",
    "    \"\"\"\n",
    "    Interpolation of the SSH AVISO (CMEMS L4 products).\n",
    "    \"\"\"\n",
    "    def __init__(self, adt):\n",
    "        self.adt = adt\n",
    "        ts = adt.time.data\n",
    "\n",
    "        # The time axis must be sorted in ascending order.\n",
    "        assert numpy.all(ts[:-1] <= ts[1:])\n",
    "\n",
    "        # The frequency between the grids must be constant.\n",
    "        frequency = set(numpy.diff(ts.astype(\"datetime64[s]\").astype(\"int64\")))\n",
    "        if len(frequency) != 1:\n",
    "            raise RuntimeError(\n",
    "                \"Time series does not have a constant step between two \"\n",
    "                f\"grids: {frequency} seconds\")\n",
    "\n",
    "        # The frequency is stored in order to load the grids required to\n",
    "        # interpolate the SSH. It was computed in seconds above, so the\n",
    "        # time delta must use the same unit.\n",
    "        self.dt = numpy.timedelta64(frequency.pop(), 's')\n",
    "\n",
    "    def load_dataset(self, first_date, last_date):\n",
    "        \"\"\"Loads the 3D cube describing the SSH in time and space.\"\"\"\n",
    "        if first_date < self.adt.time[0] or last_date > self.adt.time[-1]:\n",
    "            raise IndexError(\n",
    "                f\"period [{first_date}, {last_date}] is out of range: \"\n",
    "                f\"[{self.adt.time[0]}, {self.adt.time[-1]}]\")\n",
    "        first_date = self.adt.time.sel(time=first_date, method='pad')\n",
    "        last_date = self.adt.time.sel(time=last_date, method='backfill')\n",
    "        selected = self.adt.loc[dict(time=slice(first_date, last_date))]\n",
    "        selected = selected.compute()\n",
    "        return pyinterp.backends.xarray.Grid3D(selected.adt)\n",
    "\n",
    "    def interpolate(self, lon, lat, time):\n",
    "        \"\"\"Interpolate the SSH to the required coordinates\"\"\"\n",
    "        interpolator = self.load_dataset(time.min(), time.max())\n",
    "        ssh = interpolator.trivariate(dict(longitude=lon,\n",
    "                                           latitude=lat,\n",
    "                                           time=time),\n",
    "                                      interpolator='bilinear')\n",
    "        return ssh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can update our parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameters.ssh_plugin = CMEMS(ds)"
   ]
  },
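  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The constructor of the `CMEMS` class above rejects a time axis that is unsorted or irregularly sampled. The sketch below reproduces that idea with the standard library only, as an illustration rather than the code the plug-in runs. The function `constant_step` is a hypothetical name introduced for this example.\n",
    "\n",
    "```python\n",
    "import datetime\n",
    "\n",
    "\n",
    "def constant_step(times):\n",
    "    # Collect the distinct steps between consecutive timestamps.\n",
    "    steps = {later - earlier for earlier, later in zip(times, times[1:])}\n",
    "    # A negative step means the time axis is not sorted.\n",
    "    if any(step < datetime.timedelta(0) for step in steps):\n",
    "        raise ValueError('time axis is not sorted')\n",
    "    # Exactly one distinct step is expected for a regular series.\n",
    "    if len(steps) != 1:\n",
    "        raise RuntimeError(f'time step is not constant: {sorted(steps)}')\n",
    "    return steps.pop()\n",
    "\n",
    "\n",
    "start = datetime.datetime(2000, 1, 1)\n",
    "daily = [start + datetime.timedelta(days=index) for index in range(4)]\n",
    "print(constant_step(daily))\n",
    "```\n",
    "\n",
    "Returning the step as a time delta keeps its unit attached to the value, which avoids mixing seconds and nanoseconds when the step is reused later."
   ]
  },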
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initiating orbit propagator.\n",
    "\n",
    "Initialization is simply done by [loading](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.orbit_propagator.load_ephemeris.html#swot_simulator.orbit_propagator.load_ephemeris) the ephemeris file. The satellite's one-day pass is taken into account in this case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import swot_simulator.orbit_propagator\n",
    "\n",
    "\n",
    "with open(parameters.ephemeris, \"r\") as stream:\n",
    "    orbit = swot_simulator.orbit_propagator.calculate_orbit(parameters, stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iterate on the half-orbits of a period.\n",
    "\n",
    "To iterate over all the half-orbits of a period, call the method [iterate](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.orbit_propagator.Orbit.iterate.html#swot_simulator.orbit_propagator.Orbit.iterate). This method returns the cycle number, pass number, and start date of every half-orbit within the period. If the start date is not set, the method uses the current date. If the end date is not set, the method uses the start date plus the cycle duration.\n",
    "\n",
    "In our case, we generate a cycle starting on January 1, 2000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_date = numpy.datetime64(\"2000-01-01\")\n",
    "iterator = orbit.iterate(first_date)\n",
    "cycle_number, pass_number, date = next(iterator)\n",
    "cycle_number, pass_number, date"
   ]
  },
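  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default end-date rule described above can be pictured with the following minimal sketch, which is not the library's implementation. The generator `iterate_passes` is a hypothetical stand-in for `Orbit.iterate`; the one-day cycle, the 45-minute pass duration, and the numbering starting at 1 are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import datetime\n",
    "\n",
    "\n",
    "def iterate_passes(first_date, cycle_duration, pass_duration, last_date=None):\n",
    "    # If no end date is given, stop one cycle after the start date.\n",
    "    if last_date is None:\n",
    "        last_date = first_date + cycle_duration\n",
    "    passes_per_cycle = round(cycle_duration / pass_duration)\n",
    "    index = 0\n",
    "    date = first_date\n",
    "    while date < last_date:\n",
    "        # Cycle and pass numbers start at 1 (assumption).\n",
    "        cycle_number = index // passes_per_cycle + 1\n",
    "        pass_number = index % passes_per_cycle + 1\n",
    "        yield cycle_number, pass_number, date\n",
    "        index += 1\n",
    "        date = first_date + index * pass_duration\n",
    "\n",
    "\n",
    "passes = list(\n",
    "    iterate_passes(datetime.datetime(2000, 1, 1),\n",
    "                   cycle_duration=datetime.timedelta(days=1),\n",
    "                   pass_duration=datetime.timedelta(minutes=45)))\n",
    "print(passes[0], len(passes))\n",
    "```\n",
    "\n",
    "Because the generator is lazy, calling `next` on it returns only the first half-orbit, which is how the notebook picks the pass it simulates."
   ]
  },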
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization of measurement error generators\n",
    "\n",
    "Error initialization is done simply by calling the appropriate [class](https://swot-simulator.readthedocs.io/en/latest/generated/swot_simulator.error.generator.Generator.html#swot_simulator.error.generator.Generator). The initialization of the wet troposphere error generator takes a little time (about 40 seconds), which explains the processing time for the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import swot_simulator.error.generator\n",
    "\n",
    "error_generator = swot_simulator.error.generator.Generator(\n",
    "    parameters, first_date, orbit.orbit_duration())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate the positions under the swath.\n",
    "\n",
    "To perform this task, the following function is implemented.\n",
    "\n",
    "> If the position of the pass is outside the area of interest (`parameters.area`),\n",
    "> the generation of the pass can return `None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_one_track(pass_number, date, orbit):\n",
    "    # Compute the spatial/temporal position of the satellite\n",
    "    track = swot_simulator.orbit_propagator.calculate_pass(\n",
    "        pass_number, orbit, parameters)\n",
    "\n",
    "    # If the pass is not located in the area of interest (parameters.area),\n",
    "    # the result of the generation can be None.\n",
    "    if track is None:\n",
    "        return None\n",
    "\n",
    "    # Set the simulated date\n",
    "    track.set_simulated_date(date)\n",
    "\n",
    "    return track"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolate SSH\n",
    "\n",
    "Interpolation of the SSH for the space-time coordinates generated by the simulator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def interpolate_ssh(parameters, track):\n",
    "    swath_time = numpy.repeat(track.time,\n",
    "                              track.lon.shape[1]).reshape(track.lon.shape)\n",
    "    ssh = parameters.ssh_plugin.interpolate(track.lon.ravel(),\n",
    "                                            track.lat.ravel(),\n",
    "                                            swath_time.ravel())\n",
    "    return ssh.reshape(track.lon.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculation of instrumental errors\n",
    "\n",
    "Simulation of instrumental errors.\n",
    "\n",
    "> KaRIN's instrumental noise can be modulated by wave heights.\n",
    "> The `swh` parameter takes either a constant or a matrix defining\n",
    "> the SWH for the swath positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_instrumental_errors(error_generator, cycle_number, pass_number,\n",
    "                                 orbit, track):\n",
    "    return error_generator.generate(cycle_number,\n",
    "                                    pass_number,\n",
    "                                    orbit.curvilinear_distance,\n",
    "                                    track.time,\n",
    "                                    track.x_al,\n",
    "                                    track.x_ac,\n",
    "                                    swh=2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate the sum of the simulated errors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sum_error(errors, swath=True):\n",
    "    \"\"\"Calculate the sum of errors\"\"\"\n",
    "    dims = 2 if swath else 1\n",
    "    return numpy.add.reduce(\n",
    "        [item for item in errors.values() if len(item.shape) == dims])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the swath dataset\n",
    "\n",
    "Generation of the simulated swath. The function returns an xarray dataset for the half-orbit generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import swot_simulator.netcdf\n",
    "\n",
    "\n",
    "def generate_dataset(cycle_number,\n",
    "                     pass_number,\n",
    "                     track,\n",
    "                     ssh,\n",
    "                     noise_errors,\n",
    "                     complete_product=False):\n",
    "    product = swot_simulator.netcdf.Swath(track, central_pixel=True)\n",
    "    # Mask to set the measurements outside the requirements of the mission to\n",
    "    # NaN.\n",
    "    mask = track.mask()\n",
    "    ssh *= mask\n",
    "    product.ssh(ssh + sum_error(noise_errors))\n",
    "    product.simulated_true_ssh(ssh)\n",
    "    for error in noise_errors.values():\n",
    "        # Only the swaths must be masked\n",
    "        if len(error.shape) == 2:\n",
    "            error *= mask\n",
    "    product.update_noise_errors(noise_errors)\n",
    "    return product.to_xarray(cycle_number, pass_number, complete_product)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Swath generation.\n",
    "\n",
    "Now we can combine the different components to generate the swath."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dask\n",
    "import dask.distributed\n",
    "\n",
    "\n",
    "def generate_swath(cycle_number, pass_number, date, parameters,\n",
    "                   error_generator, orbit):\n",
    "    client = dask.distributed.get_client()\n",
    "    # Scatter big data\n",
    "    orbit_ = client.scatter(orbit)\n",
    "    error_generator_ = client.scatter(error_generator)\n",
    "    # Compute swath positions\n",
    "    track = dask.delayed(generate_one_track)(pass_number, date, orbit_)\n",
    "    # Interpolate SSH\n",
    "    ssh = dask.delayed(interpolate_ssh)(parameters, track)\n",
    "    # Simulate instrumental errors\n",
    "    noise_errors = dask.delayed(generate_instrumental_errors)(error_generator_,\n",
    "                                                              cycle_number,\n",
    "                                                              pass_number,\n",
    "                                                              orbit_, track)\n",
    "    # Finally generate the dataset\n",
    "    return dask.delayed(generate_dataset)(\n",
    "        cycle_number, pass_number, track, ssh, noise_errors,\n",
    "        parameters.complete_product).compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The simulator calculation can be distributed on a Dask cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dask.distributed\n",
    "\n",
    "# A local cluster is used here.\n",
    "cluster = dask.distributed.LocalCluster()\n",
    "client = dask.distributed.Client(cluster)\n",
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_generator_ = client.scatter(error_generator)\n",
    "parameters_ = client.scatter(parameters)\n",
    "orbit_ = client.scatter(orbit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "future = client.submit(generate_swath, cycle_number, pass_number, date,\n",
    "                       parameters_, error_generator_, orbit_)\n",
    "ds = client.gather(future)\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compute a set of passes, you can use the following code:\n",
    "\n",
    "    futures = []\n",
    "    for cycle_number, pass_number, date in iterator:\n",
    "        futures.append(client.submit(generate_swath,\n",
    "                                     cycle_number,\n",
    "                                     pass_number,\n",
    "                                     date,\n",
    "                                     parameters_,\n",
    "                                     error_generator_,\n",
    "                                     orbit_))\n",
    "    client.gather(futures)\n",
    "\n",
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot\n",
    "import cartopy.crs\n",
    "import cartopy.feature\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Selection of a reduced geographical area for visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected = ds.where((ds.latitude > -50) & (ds.latitude < -40), drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simulated SSH measurements (Interpolated SSH and simulated instrumental errors)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = matplotlib.pyplot.figure(figsize=(24, 12))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=cartopy.crs.PlateCarree())\n",
    "contourf = ax.contourf(selected.longitude,\n",
    "                       selected.latitude,\n",
    "                       selected.ssh_karin,\n",
    "                       transform=cartopy.crs.PlateCarree(),\n",
    "                       levels=255,\n",
    "                       cmap='jet')\n",
    "fig.colorbar(contourf, orientation='vertical')\n",
    "ax.set_extent([60, 69, -50, -40], crs=cartopy.crs.PlateCarree())\n",
    "ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)\n",
    "ax.coastlines()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simulated instrumental errors, one figure per error component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plot every variable holding a simulated error.\n",
    "for item in selected.variables:\n",
    "    if item.startswith(\"simulated_error\"):\n",
    "        variable = selected.variables[item]\n",
    "        fig = matplotlib.pyplot.figure(figsize=(18, 8))\n",
    "        ax = fig.add_subplot(1, 1, 1)\n",
    "        image = ax.imshow(variable.T,\n",
    "                          extent=[0, len(selected.num_lines), -70, 70],\n",
    "                          cmap='jet')\n",
    "        ax.set_title(variable.attrs['long_name'] + \" (\" +\n",
    "                     variable.attrs['units'] + \")\")\n",
    "        ax.set_xlabel(\"num_lines\")\n",
    "        ax.set_ylabel(\"num_pixels\")\n",
    "        fig.colorbar(image,\n",
    "                     orientation='vertical',\n",
    "                     fraction=0.046 * 70 / 250,\n",
    "                     pad=0.04)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "b5abb4912561cb3944b8ab88aeb89e73503903f2213247e02c73ef25ab3e9321"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('conda-forge': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
